In this notebook, I will do prior predictive checks to make sure the values generated are reasonable and not too extreme.
Then, I will perform simulation-based calibration to evaluate whether the model can recover the parameters drawn from the prior.
Imports¶
In [83]:
import numpy as np
import pymc as pm
from bokeh.plotting import figure, show
from bokeh.layouts import gridplot
from bokeh.models import ColumnDataSource, Span
from bokeh.io import output_notebook
output_notebook()
from tqdm import tqdm
Model description¶
With biological replicates¶
$$ \underset{\textcolor{purple}{\text{posterior}}}{\pi \left( \underline{\alpha}, \underline{b^{(x)}},\underline{b^{(s)}}, \mu_b, \sigma_b, \underline{x^{mRNA}}, X^{mRNA} | \underline{x^{seq}}, \underline{s^{seq}}, \underline{s^{spike}} \right)} \propto$$$$\underset{\textcolor{purple}{\text{spike-in likelihood}}}{\pi \left(\underline{s^{spike}}, \underline{s^{seq}} | \underline{b^{(s)}} \right)} \underset{\textcolor{purple}{\text{priors}}}{\pi \left(\underline{b^{(s)}} | \mu_b, \sigma_b \right) \pi \left(\mu_b \right) \pi \left(\sigma_b \right)}\times$$$$\underset{\textcolor{purple}{\text{txtome likelihood}}}{\pi \left(\underline{x^{seq}} | \underline{x^{mRNA}}, \underline{b^{(x)}} \right)} \underset{\textcolor{purple}{\text{priors}}}{\pi \left(\underline{b^{(x)}} | \mu_b, \sigma_b \right) }\times$$$$ \underset{\textcolor{purple}{\text{idealized likelihood}}}{\pi \left( \underline{x^{mRNA}} | X^{mRNA}, \underline{\alpha}\right)} \underset{\textcolor{purple}{\text{priors}}}{\pi \left(X^{mRNA} \right) \pi \left(\underline{\alpha} | \underline{\alpha_{global}} \right) \pi \left( \underline{\alpha_{global}}\right)}$$In [ ]:
In [ ]:
Prior predictive checks¶
Does my prior generate plausible data? If I believe this model and these priors, what kind of data do I expect to see?
I'm specifically checking for:
- Values of x_seq_obs up to 620,000 (largest value in my un-normalized dataset) but no extreme values
- Alpha distribution is generally uniform and not dominated by 0 or 1
- Does b_x have extreme values?
In [113]:
R = 2 # num reps
G = 3 # num genes
K = 4 # num spike-in species
s_spike_true = np.array([[2, 20, 1000, 2000], [2, 20, 1000, 2000]])
Define model
In [114]:
with pm.Model() as model:
# spike-in/transcriptome priors
mu_b = pm.HalfNormal("mu_b", sigma = 1)
#sigma_b = pm.LogNormal("sigma_b", mu = 0.1, sigma = 0.3)
b_x = pm.HalfNormal('b_x', sigma = mu_b, shape = (R,G))
b_s = pm.HalfNormal("b_s", sigma = mu_b, shape = (R,K))
# idealized priors
# hyperprior global alpha
alpha_global = pm.Dirichlet('alpha_global', a = np.ones(G), shape = G)
X_mrna = pm.LogNormal("X_mrna", mu = np.log(50000), sigma = 1, shape = R)
conc = 500 # to give flexibility
alpha = pm.Dirichlet('alpha', a = conc * alpha_global, shape = (R,G))
# likelihood spike ins
s_seq_obs = pm.Poisson("s_seq_obs", mu = s_spike_true * b_s)
# likelihood transcriptome
x_mrna = X_mrna[:, None] * alpha
x_seq_obs = pm.Poisson("x_seq_obs", mu = x_mrna * b_x)
prior_pred = pm.sample_prior_predictive(samples = 1000)
Sampling: [X_mrna, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_seq_obs]
In [126]:
with pm.Model() as model:
# spike-in/transcriptome priors
mu_b = pm.HalfNormal("mu_b", sigma = 1)
sigma_b = pm.LogNormal("sigma_b", mu = 0.1, sigma = 0.1)
b_x = pm.LogNormal('b_x', mu = mu_b, sigma = sigma_b, shape = (R,G))
b_s = pm.LogNormal("b_s", mu = mu_b, sigma = sigma_b, shape = (R,K))
# idealized priors
# hyperprior global alpha
alpha_global = pm.Dirichlet('alpha_global', a = np.ones(G), shape = G)
X_mrna = pm.LogNormal("X_mrna", mu = np.log(2500), sigma = 1, shape = R)
conc = 500 # to give flexibility
alpha = pm.Dirichlet('alpha', a = conc * alpha_global, shape = (R,G))
# likelihood spike ins
s_seq_obs = pm.Poisson("s_seq_obs", mu = s_spike_true * b_s)
# likelihood transcriptome
#x_mrna = X_mrna[:, None] * alpha
x_mrna = pm.Multinomial("x_mrna", n = X_mrna, p = alpha)
x_seq_obs = pm.Poisson("x_seq_obs", mu = x_mrna * b_x)
prior_pred = pm.sample_prior_predictive(samples = 1000)
Sampling: [X_mrna, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, sigma_b, x_mrna, x_seq_obs]
Take a look
In [133]:
param = 'x_seq_obs' # x_seq_obs, s_seq_obs, alpha, b_x, b_s
In [134]:
prior_samples = prior_pred.prior[param].values
plots = []
# Loop through each replicate (2 reps)
for rep in range(R):
for gene_ind in range(G):
prior_gene = prior_samples[0, :, rep, gene_ind]
# set up histogram
p1 = figure(title=f"Prior Predictive Check, {param}, Rep {rep+1}, Gene {gene_ind}",
x_axis_label=param,y_axis_label="Count",
width=400,height=300,background_fill_color="#fafafa")
hist, edges = np.histogram(prior_gene, bins=50)
p1.quad(
top=hist,bottom=0, left=edges[:-1],right=edges[1:],
fill_color="skyblue",line_color="black",legend_label="1000 random samples")
plots.append(p1)
n_cols = 4
# Split into rows
grid = [plots[i:i+n_cols] for i in range(0, len(plots), n_cols)]
# Display
show(gridplot(grid))
- alpha distribution are not heavily skewed
- x_seq_obs is skewed towards 0, with tails up to 500000
- s_seq_obs matches the input
- b_x is skewed towards 0, tails up to 7
- b_s skewed toweards 0, tails up to 10
Look at alpha_global
In [135]:
prior_samples = prior_pred.prior['alpha_global'].values
plots = []
for gene_ind in range(G):
prior_gene = prior_samples[0, :, gene_ind] # get prior samples for this replicate and gene
p1 = figure(title=f"Prior Predictive Check, alpha global, Gene {gene_ind}",
x_axis_label="alpha-global",y_axis_label="Count",
width=400,height=300,background_fill_color="#fafafa")
hist, edges = np.histogram(prior_gene, bins=50)
p1.quad(
top=hist,bottom=0,left=edges[:-1],right=edges[1:],
fill_color="skyblue",line_color="black", legend_label="1000 random samples")
plots.append(p1)
n_cols = 4
# Split into rows
grid = [plots[i:i+n_cols] for i in range(0, len(plots), n_cols)]
# Display
show(gridplot(grid))
- alpha_global not skewed too heavily
Notes:
- Using logNormal for b_x and b_s gives highly skewed distributions
Functions to simulate data from prior and run the model¶
In [151]:
def simulate_data(mu_b_mu, X_mrna_mu, X_mrna_sigma, conc, G=3, K=4, R=2):
"""Simulate data for the hierarchical model."""
mu_b = np.abs(np.random.normal(0, mu_b_mu))
b_x_true = np.abs(np.random.normal(0, mu_b, size=(R, G)))
b_s_true = np.abs(np.random.normal(0, mu_b, size=(R, K)))
alpha_global_true = np.random.dirichlet(alpha=np.ones(G))
alpha_true = np.random.dirichlet(alpha=conc * alpha_global_true, size=R)
X_mrna_true = np.random.lognormal(X_mrna_mu, X_mrna_sigma, size=R)
x_mrna_true = np.array([
np.random.multinomial(n=total, pvals=alpha_true[r])
for r, total in enumerate(X_mrna_true)
]) #x_mrna_true = X_mrna_true[:, None] * alpha_true
s_spike_true = np.array([[2, 20, 1000, 2000], [2, 20, 1000, 2000]])
x_seq_true = np.random.poisson(lam=b_x_true * x_mrna_true)
s_seq_true = np.random.poisson(lam=b_s_true * s_spike_true)
true_params = {
"mu_b": mu_b,
"b_x": b_x_true,
"b_s": b_s_true,
"alpha_global": alpha_global_true,
"X_mrna": X_mrna_true,
"alpha": alpha_true,
}
return s_spike_true, x_seq_true, s_seq_true, true_params
def run_model(s_spike_true, x_seq_true, s_seq_true, G, K, R, mu_b_mu, X_mrna_mu, X_mrna_sigma, conc):
"""Run the PyMC model and return trace and predictions."""
with pm.Model() as model:
# Spike-in/transcriptome priors
mu_b = pm.HalfNormal("mu_b", sigma=mu_b_mu)
b_x = pm.HalfNormal('b_x', sigma=mu_b, shape=(R, G))
b_s = pm.HalfNormal("b_s", sigma=mu_b, shape=(R, K))
# Idealized priors
# Hyperprior global alpha
alpha_global = pm.Dirichlet('alpha_global', a=np.ones(G), shape=G)
alpha = pm.Dirichlet('alpha', a=conc * alpha_global, shape=(R, G))
# Likelihood spike ins
s_seq_obs = pm.Poisson("s_seq_obs", mu=s_spike_true * b_s, observed=s_seq_true)
# Likelihood transcriptome
X_mrna1 = pm.LogNormal("X_mrna1", mu=X_mrna_mu, sigma=X_mrna_sigma, shape=R)
X_mrna = pm.Deterministic("X_mrna", pm.math.maximum(pm.math.floor(X_mrna1), 1))
# x_mrna = X_mrna[:, None] * alpha
x_mrna = pm.Multinomial('x_mrna', n = X_mrna, p = alpha, shape = (R,G))
x_seq_obs = pm.Poisson("x_seq_obs", mu=x_mrna * b_x, observed=x_seq_true)
trace = pm.sample(1000, tune=1000, target_accept=0.90, progressbar = False)
prior_pred = pm.sample_prior_predictive(samples=1000)
ppc = pm.sample_posterior_predictive(trace)
return trace, prior_pred, ppc
Functions to compute rank and plot¶
In [139]:
def compute_ranks(true_vals, posterior_samples):
"""Compute ranks for SBC analysis."""
ranks = []
N_sbc = true_vals.shape[0]
for i in range(N_sbc):
# Flatten param posterior samples for ith simulation
posterior_i = posterior_samples[i] # shape (n_samples, ...)
# Flatten to 2D: (n_samples, param_dim)
posterior_i_flat = posterior_i.reshape(posterior_i.shape[0], -1)
true_i_flat = true_vals[i].reshape(-1)
ranks_i = []
for dim in range(len(true_i_flat)):
rank = np.sum(posterior_i_flat[:, dim] < true_i_flat[dim])
ranks_i.append(rank)
ranks.append(ranks_i)
return np.array(ranks) # shape (N_sbc, param_dim)
def plot_individual_posteriors(true_vals, posterior_samples, param_name):
"""Plot individual posterior distributions for each SBC run, rep, and dimension."""
N_sbc = len(posterior_samples)
plots = []
plot_titles = []
for sbc_run in range(N_sbc):
# Get shapes for this parameter
post_shape = posterior_samples[sbc_run].shape # (n_chains, n_samples, ...)
true_shape = true_vals[sbc_run].shape if hasattr(true_vals[sbc_run], 'shape') else ()
# Handle different parameter types
if len(post_shape) == 2: # Scalar parameter (n_chains, n_samples)
post_samples = posterior_samples[sbc_run].flatten()
true_val = true_vals[sbc_run]
p = figure(title=f'{param_name} - Run {sbc_run+1}', width=250, height=200)
hist, edges = np.histogram(post_samples, bins=30, density=True)
# Ensure arrays have same length
left_edges = edges[:-1]
right_edges = edges[1:]
p.quad(top=hist, bottom=0, left=left_edges, right=right_edges,
fill_alpha=0.7, line_color="black", fill_color="skyblue")
vline = Span(location=true_val, dimension='height', line_color='red',
line_dash='dashed', line_width=2)
p.add_layout(vline)
p.xaxis.axis_label = 'Value'
p.yaxis.axis_label = 'Density'
plots.append(p)
elif len(post_shape) == 3: # 1D parameter (n_chains, n_samples, dim)
for dim in range(post_shape[2]):
post_samples = posterior_samples[sbc_run][:, :, dim].flatten()
true_val = true_vals[sbc_run][dim]
p = figure(title=f'{param_name}[{dim}] - Run {sbc_run+1}', width=250, height=200)
hist, edges = np.histogram(post_samples, bins=30, density=True)
# Ensure arrays have same length
left_edges = edges[:-1]
right_edges = edges[1:]
p.quad(top=hist, bottom=0, left=left_edges, right=right_edges,
fill_alpha=0.7, line_color="black", fill_color="skyblue")
vline = Span(location=true_val, dimension='height', line_color='red',
line_dash='dashed', line_width=2)
p.add_layout(vline)
p.xaxis.axis_label = 'Value'
p.yaxis.axis_label = 'Density'
plots.append(p)
elif len(post_shape) == 4: # 2D parameter (n_chains, n_samples, dim1, dim2)
for rep in range(post_shape[2]):
for gene in range(post_shape[3]):
post_samples = posterior_samples[sbc_run][:, :, rep, gene].flatten()
true_val = true_vals[sbc_run][rep, gene]
p = figure(title=f'{param_name}[{rep},{gene}] - Run {sbc_run+1}',
width=250, height=200)
hist, edges = np.histogram(post_samples, bins=30, density=True)
# Ensure arrays have same length
left_edges = edges[:-1]
right_edges = edges[1:]
p.quad(top=hist, bottom=0, left=left_edges, right=right_edges,
fill_alpha=0.7, line_color="black", fill_color="skyblue")
vline = Span(location=true_val, dimension='height', line_color='red',
line_dash='dashed', line_width=2)
p.add_layout(vline)
p.xaxis.axis_label = 'Value'
p.yaxis.axis_label = 'Density'
plots.append(p)
# Arrange plots in grid
n_plots = len(plots)
if n_plots == 0:
print(f"No plots generated for {param_name}")
return
n_cols = min(4, n_plots)
grid_plots = []
for i in range(0, n_plots, n_cols):
row = plots[i:i+n_cols]
# Pad row with None if needed
while len(row) < n_cols:
row.append(None)
grid_plots.append(row)
grid = gridplot(grid_plots)
show(grid)
def plot_mean_vs_true(true_vals, posterior_samples, param_name):
"""Plot posterior means vs true values."""
mean_post = []
true_flat = []
for i in range(len(posterior_samples)):
post_mean = posterior_samples[i].mean(axis=0)
true_val = true_vals[i]
if np.isscalar(post_mean):
mean_post.append(post_mean)
true_flat.append(true_val)
else:
mean_post.extend(post_mean.flatten())
true_flat.extend(true_val.flatten() if hasattr(true_val, 'flatten') else [true_val])
p = figure(title=f'Mean Posterior vs True for {param_name}', width=400, height=400)
p.scatter(true_flat, mean_post, alpha=0.6, size=8)
# Add diagonal line
min_val = min(min(true_flat), min(mean_post))
max_val = max(max(true_flat), max(mean_post))
p.line([min_val, max_val], [min_val, max_val], color="red", line_dash="dashed",
line_width=2, legend_label='Perfect Recovery')
p.xaxis.axis_label = 'True Value'
p.yaxis.axis_label = 'Posterior Mean'
p.legend.location = "top_left"
show(p)
def plot_rank_histogram(ranks, param_name):
"""Plot rank histogram for SBC analysis."""
ranks_flat = ranks.flatten()
p = figure(title=f'Rank Histogram for {param_name}', width=600, height=300)
max_rank = int(np.max(ranks_flat))
hist, edges = np.histogram(ranks_flat, bins=max_rank + 1, range=(0, max_rank + 1))
# Ensure arrays have same length
left_edges = edges[:-1]
right_edges = edges[1:]
p.quad(top=hist, bottom=0, left=left_edges, right=right_edges,
fill_color="green", line_color="black", alpha=0.7)
p.xaxis.axis_label = 'Rank'
p.yaxis.axis_label = 'Frequency'
show(p)
def plot_all_for_param(true_vals, posterior_samples, param_name):
"""Plot all SBC diagnostics for a parameter."""
print(f"\n{'='*50}")
print(f"SBC Results for parameter: {param_name}")
print(f"{'='*50}")
# Individual posterior distributions
#print("\n1. Individual Posterior Distributions:")
#plot_individual_posteriors(true_vals, posterior_samples, param_name)
# Mean vs true scatter plot
print("\n1. Posterior Mean vs True Value:")
plot_mean_vs_true(true_vals, posterior_samples, param_name)
# Rank histogram
print("\n2. Rank Histogram:")
ranks = compute_ranks(true_vals, posterior_samples)
plot_rank_histogram(ranks, param_name)
Function to run entire SBC¶
In [140]:
def run_sbc(mu_b_mu, X_mrna_mu, X_mrna_sigma, conc, N_sbc=20, G=3, K=4, R=2):
"""Run Simulation-Based Calibration."""
true_params_list = {
'mu_b': [],
'b_x': [],
'b_s': [],
'alpha_global': [],
'X_mrna': [],
'alpha': [],
}
posterior_samples = {
'mu_b': [],
'b_x': [],
'b_s': [],
'alpha_global': [],
'X_mrna': [],
'alpha': [],
}
for i in tqdm(range(N_sbc)):
s_spike_true, x_seq_true, s_seq_true, true_param_dict = simulate_data(mu_b_mu, X_mrna_mu, X_mrna_sigma, conc, G, K, R)
trace, prior_pred, ppc = run_model(s_spike_true, x_seq_true, s_seq_true, G, K, R, mu_b_mu, X_mrna_mu, X_mrna_sigma, conc)
# Extract posterior samples for each parameter
for param in posterior_samples.keys():
posterior_samples[param].append(trace.posterior[param].values)
true_params_list[param].append(true_param_dict[param])
# Convert lists to arrays
for param in posterior_samples.keys():
posterior_samples[param] = np.array(posterior_samples[param])
true_params_list[param] = np.array(true_params_list[param])
return true_params_list, posterior_samples
Run entire SBC and plot¶
In [141]:
# Suppress specific warning message
import warnings
warnings.filterwarnings("ignore", message='install "ipywidgets" for Jupyter support')
In [155]:
mu_b_mu = 3
X_mrna_mu = np.log(2500)
X_mrna_sigma = 1
conc = 500
N_sbc = 30
np.random.seed(42)
# Run SBC to get true params and posterior samples
print("Running SBC analysis...")
true_params, posterior_samples = run_sbc(mu_b_mu, X_mrna_mu, X_mrna_sigma, conc, N_sbc, G, K, R)
Running SBC analysis...
Multiprocess sampling (4 chains in 4 jobs) | 0/30 [00:00<?, ?it/s] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 7 seconds. There were 3566 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 1/30 [00:08<04:12, 8.70s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 7 seconds. There were 3577 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 2/30 [00:17<04:01, 8.62s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 8 seconds. There were 3564 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 3/30 [00:29<04:38, 10.32s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 9 seconds. There were 3579 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 4/30 [00:41<04:41, 10.81s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 11 seconds. There were 3566 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 5/30 [00:54<04:48, 11.54s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 12 seconds. There were 3607 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 6/30 [01:08<05:00, 12.53s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 9 seconds. There were 3664 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 7/30 [01:21<04:51, 12.67s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 9 seconds. There were 3591 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 8/30 [01:32<04:29, 12.23s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 9 seconds. There were 3654 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 9/30 [01:43<04:06, 11.75s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 10 seconds. There were 3613 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 10/30 [01:55<03:57, 11.88s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 7 seconds. There were 3495 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 11/30 [02:04<03:28, 10.97s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 7 seconds. There were 3570 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 12/30 [02:16<03:21, 11.17s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 7 seconds. There were 3542 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 13/30 [02:25<03:01, 10.69s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 8 seconds. There were 3520 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 14/30 [02:35<02:45, 10.35s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 8 seconds. There were 3540 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 15/30 [02:45<02:33, 10.25s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 9 seconds. There were 3590 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 16/30 [02:58<02:35, 11.10s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 8 seconds. There were 3549 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 17/30 [03:08<02:21, 10.90s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 8 seconds. There were 3566 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 18/30 [03:19<02:09, 10.80s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 7 seconds. There were 3554 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 19/30 [03:30<01:59, 10.87s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 10 seconds. There were 3615 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 20/30 [03:42<01:53, 11.30s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 9 seconds. There were 3585 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 21/30 [03:53<01:39, 11.08s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 7 seconds. There were 3579 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 22/30 [04:02<01:23, 10.45s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 9 seconds. There were 3611 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 23/30 [04:13<01:14, 10.62s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 7 seconds. There were 3524 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 24/30 [04:24<01:04, 10.73s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 10 seconds. There were 3542 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 25/30 [04:36<00:56, 11.25s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 10 seconds. There were 3610 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 26/30 [04:49<00:46, 11.66s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 10 seconds. There were 3642 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs) | 27/30 [05:03<00:36, 12.33s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 7 seconds. There were 3652 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs)█▎ | 28/30 [05:12<00:22, 11.48s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 9 seconds. There were 3579 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
Multiprocess sampling (4 chains in 4 jobs)██▋ | 29/30 [05:24<00:11, 11.45s/it] CompoundStep >NUTS: [mu_b, b_x, b_s, alpha_global, alpha, X_mrna1] >Metropolis: [x_mrna] Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 7 seconds. There were 3505 divergences after tuning. Increase `target_accept` or reparameterize. The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details Sampling: [X_mrna1, alpha, alpha_global, b_s, b_x, mu_b, s_seq_obs, x_mrna, x_seq_obs] Sampling: [s_seq_obs, x_seq_obs]
100%|█████████████████████████████████████████| 30/30 [05:32<00:00, 11.10s/it]
In [ ]:
# import pickle
# # save results pickle
# filename = 'true_params, posterior_samples, run2'
# if not filename.endswith('.pkl'):
# filename += '.pkl'
# sbc_data = {'true_params': true_params,
# 'posterior_samples':posterior_samples}
# with open(filename, 'wb') as f:
# pickle.dump(sbc_data,f)
In [156]:
# For each parameter, plot SBC results
for param in true_params.keys():
plot_all_for_param(true_params[param], posterior_samples[param], param)
================================================== SBC Results for parameter: mu_b ================================================== 1. Posterior Mean vs True Value:
BokehUserWarning: ColumnDataSource's columns must be of the same length. Current lengths: ('x', 30), ('y', 30000)
2. Rank Histogram:
================================================== SBC Results for parameter: b_x ================================================== 1. Posterior Mean vs True Value:
BokehUserWarning: ColumnDataSource's columns must be of the same length. Current lengths: ('x', 180), ('y', 180000)
2. Rank Histogram:
================================================== SBC Results for parameter: b_s ================================================== 1. Posterior Mean vs True Value:
BokehUserWarning: ColumnDataSource's columns must be of the same length. Current lengths: ('x', 240), ('y', 240000)
2. Rank Histogram:
================================================== SBC Results for parameter: alpha_global ================================================== 1. Posterior Mean vs True Value:
BokehUserWarning: ColumnDataSource's columns must be of the same length. Current lengths: ('x', 90), ('y', 90000)
2. Rank Histogram:
================================================== SBC Results for parameter: X_mrna ================================================== 1. Posterior Mean vs True Value:
BokehUserWarning: ColumnDataSource's columns must be of the same length. Current lengths: ('x', 60), ('y', 60000)
2. Rank Histogram:
================================================== SBC Results for parameter: alpha ================================================== 1. Posterior Mean vs True Value:
BokehUserWarning: ColumnDataSource's columns must be of the same length. Current lengths: ('x', 180), ('y', 180000)
2. Rank Histogram:
In [ ]:
In [ ]:
In [ ]:
%load_ext watermark
%watermark -v -p numpy,bokeh,pymc,pandas,tqdm,warnings,jupyterlab